{
  "cells": [
    {
      "cell_type": "markdown",
      "id": "K5zCsXptwfL9",
      "metadata": {
        "id": "K5zCsXptwfL9"
      },
      "source": [
        "Copyright 2023 Google LLC.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "e0b13011",
      "metadata": {
        "id": "e0b13011"
      },
      "outputs": [],
      "source": [
        "import shutil\n",
        "\n",
        "import numpy as np\n",
        "import os\n",
        "from PIL import Image\n",
        "import sys\n",
        "from shutil import copyfile\n",
        "from pathlib import Path\n",
        "\n",
        "from diffusers.schedulers import LMSDiscreteScheduler\n",
        "from diffusers import StableDiffusionPipeline\n",
        "\n",
        "\n",
        "import torch\n",
        "\n",
        "import torchvision.transforms as transforms\n",
        "from transformers import CLIPProcessor, CLIPModel, AutoTokenizer\n",
        "\n",
        "import glob\n",
        "import argparse"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "d88f5d8f",
      "metadata": {
        "id": "d88f5d8f"
      },
      "source": [
        "## Choose concept and seed"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "7ed9b6fe",
      "metadata": {
        "id": "7ed9b6fe"
      },
      "outputs": [],
      "source": [
        "concept = 'corn'\n",
        "target_seed = 55\n",
        "folder = f'./{concept}'\n",
        "prompt = f'a photo of a '\n",
        "num_inference_steps = 25"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "41911d50",
      "metadata": {
        "id": "41911d50"
      },
      "source": [
        "## Load model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "06939ef6",
      "metadata": {
        "id": "06939ef6"
      },
      "outputs": [],
      "source": [
        "# Stable Diffusion pipeline on the GPU, switched to the LMS discrete scheduler.\n",
        "pipe = StableDiffusionPipeline.from_pretrained(\"stabilityai/stable-diffusion-2-1-base\").to(\"cuda\")\n",
        "pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)\n",
        "pipe.set_progress_bar_config(disable=True)\n",
        "\n",
        "# Add the placeholder token, remember its id, and resize the text-encoder embeddings so the\n",
        "# new id has a row. The input embedding weights are then frozen (requires_grad=False).\n",
        "pipe.tokenizer.add_tokens('\u003c\u003e')\n",
        "trained_id = pipe.tokenizer.convert_tokens_to_ids('\u003c\u003e')\n",
        "pipe.text_encoder.resize_token_embeddings(len(pipe.tokenizer))\n",
        "_ = pipe.text_encoder.get_input_embeddings().weight.requires_grad_(False)\n",
        "\n",
        "# CLIP model (on the GPU), its processor and its tokenizer, all from the same checkpoint.\n",
        "clip_model = CLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\").to('cuda')\n",
        "clip_processor = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
        "clip_tokenizer = AutoTokenizer.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
        "\n",
        "# Single-step torchvision pipeline wrapping ToTensor.\n",
        "transform_tensor = transforms.Compose([transforms.ToTensor()])"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "ca5eaa50",
      "metadata": {
        "id": "ca5eaa50"
      },
      "source": [
        "## Auxiliary functions"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "c25bbe5b",
      "metadata": {
        "id": "c25bbe5b"
      },
      "outputs": [],
      "source": [
        "def clip_transform(image_tensor):\n",
        "    \"\"\"Resize an image tensor to 224x224 and normalize it channel-wise.\n",
        "\n",
        "    Args:\n",
        "        image_tensor: Tensor accepted by `torch.nn.functional.interpolate` with a two-element\n",
        "            `size`; presumably laid out as (batch, channels, height, width) - confirm with callers.\n",
        "\n",
        "    Returns:\n",
        "        The bicubically resized tensor after `transforms.Normalize` with the fixed\n",
        "        per-channel mean and std listed below.\n",
        "    \"\"\"\n",
        "    channel_mean = [0.48145466, 0.4578275, 0.40821073]\n",
        "    channel_std = [0.26862954, 0.26130258, 0.27577711]\n",
        "    resized = torch.nn.functional.interpolate(\n",
        "        image_tensor, size=(224, 224), mode='bicubic', align_corners=False)\n",
        "    return transforms.Normalize(mean=channel_mean, std=channel_std)(resized)\n",
        "\n",
        "def load_alphas(alphas_projection, token_embeddings, seed, prompt):\n",
        "    \"\"\"Compose the placeholder-token embedding from `alphas_projection` and sample one image.\n",
        "\n",
        "    The embedding is `alphas_projection @ token_embeddings`, rescaled so that its norm equals\n",
        "    the notebook-level `avg_norm`, and written into row `trained_id` of the text encoder's\n",
        "    token-embedding matrix. `avg_norm`, `pipe`, `trained_id` and `num_inference_steps` are\n",
        "    notebook globals and must be defined before the first call.\n",
        "\n",
        "    Note: if every coefficient is zero the norm is zero and the rescaled embedding is NaN.\n",
        "\n",
        "    Args:\n",
        "        alphas_projection: 1-D coefficient tensor with one entry per row of `token_embeddings`.\n",
        "        token_embeddings: 2-D matrix whose rows are the token embeddings to combine.\n",
        "        seed: Seed for the CUDA generator handed to the pipeline.\n",
        "        prompt: Text prompt handed to the pipeline.\n",
        "\n",
        "    Returns:\n",
        "        Element `[0][0]` of the pipeline's tuple output, presumably the first generated image.\n",
        "    \"\"\"\n",
        "    embedding = torch.matmul(alphas_projection, token_embeddings)\n",
        "    # Normalize to unit length, then scale to the average vocabulary embedding norm.\n",
        "    embedding = torch.mul(embedding, 1 / embedding.norm())\n",
        "    embedding = torch.mul(embedding, avg_norm)\n",
        "    # Overwrite the placeholder row in place. Writing a plain tensor under no_grad keeps the\n",
        "    # embedding matrix free of autograd history, whatever context the caller runs in.\n",
        "    with torch.no_grad():\n",
        "        pipe.text_encoder.text_model.embeddings.token_embedding.weight[trained_id] = embedding\n",
        "    generator = torch.Generator(\"cuda\").manual_seed(seed)\n",
        "    outputs = pipe(\n",
        "        prompt,\n",
        "        guidance_scale=7.5,\n",
        "        generator=generator,\n",
        "        return_dict=False,\n",
        "        num_images_per_prompt=1,\n",
        "        num_inference_steps=num_inference_steps,\n",
        "    )\n",
        "    return outputs[0][0]"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "cc1e9ff6",
      "metadata": {
        "id": "cc1e9ff6"
      },
      "source": [
        "## Load decomposition from folder"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "a5a80831",
      "metadata": {
        "id": "a5a80831"
      },
      "outputs": [],
      "source": [
        "# Concept name with spaces (for text prompts) and with underscores.\n",
        "concept_nu = concept.replace('_', ' ')\n",
        "concept_u = concept.replace(' ', '_')\n",
        "\n",
        "# Snapshot of the token-embedding matrix and the average L2 norm of its rows; load_alphas\n",
        "# rescales the composed embedding to this norm.\n",
        "orig_embeddings = pipe.text_encoder.text_model.embeddings.token_embedding.weight.clone().detach()\n",
        "# One vectorized reduction (float64 mean) instead of a Python loop with a .item() call per row.\n",
        "avg_norm = orig_embeddings.norm(dim=1).double().mean().item()\n",
        "\n",
        "# NOTE: torch.load unpickles these files; only load decompositions from a trusted source.\n",
        "alphas_dict = torch.load(f'{folder}/output/best_alphas.pt').detach_().requires_grad_(False)\n",
        "dictionary = torch.load(f'{folder}/output/dictionary.pt')\n",
        "\n",
        "# Keep the largest coefficients: at most 50, fewer if the decomposition is shorter.\n",
        "sorted_alphas, sorted_indices = torch.sort(alphas_dict, descending=True)\n",
        "num_alphas = min(50, len(sorted_indices))\n",
        "top_word_idx = [dictionary[i] for i in sorted_indices[:num_alphas]]\n",
        "top_words = [pipe.tokenizer.decode([token_id]) for token_id in top_word_idx]\n",
        "# (rank, decoded word) pairs, handy for inspecting the decomposition.\n",
        "alpha_ids = list(enumerate(top_words))\n",
        "\n",
        "# Coefficient vector over the whole vocabulary: zero everywhere except the kept tokens.\n",
        "alphas = torch.zeros(orig_embeddings.shape[0]).cuda()\n",
        "for rank, token_id in enumerate(top_word_idx):\n",
        "    alphas[token_id] = alphas_dict[sorted_indices[rank]]\n",
        "\n",
        "# CLIP text features are only compared, never trained, so no autograd graph is needed.\n",
        "with torch.no_grad():\n",
        "    clip_concept_inputs = clip_tokenizer([concept_nu], padding=True, return_tensors=\"pt\").to('cuda')\n",
        "    clip_concept_features = clip_model.get_text_features(**clip_concept_inputs)\n",
        "\n",
        "    clip_text_inputs = clip_tokenizer(top_words, padding=True, return_tensors=\"pt\").to('cuda')\n",
        "    clip_text_features = clip_model.get_text_features(**clip_text_inputs)\n",
        "\n",
        "    # Pairwise cosine similarity between the kept words.\n",
        "    clip_text_norms = clip_text_features.norm(dim=1)\n",
        "    clip_words_similarity = (torch.matmul(clip_text_features, clip_text_features.transpose(1, 0)) /\n",
        "                             torch.matmul(clip_text_norms.unsqueeze(1), clip_text_norms.unsqueeze(0)))\n",
        "    # Cosine similarity between the concept name and each kept word.\n",
        "    concept_words_similarity = torch.cosine_similarity(clip_concept_features, clip_text_features, dim=1)\n",
        "\n",
        "# Positions of words whose similarity to the concept name exceeds 0.92.\n",
        "similar_words = (concept_words_similarity.detach().cpu().numpy() \u003e 0.92).nonzero()[0]\n",
        "# Boolean matrix: True where two kept words have similarity above 0.95.\n",
        "clip_words_similarity = (clip_words_similarity.detach().cpu().numpy() \u003e 0.95)\n",
        "\n",
        "# Zero-out similar words\n",
        "for i in similar_words:\n",
        "    alphas[top_word_idx[i]] = 0"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "5813ed86",
      "metadata": {
        "id": "5813ed86"
      },
      "source": [
        "### Visualize ground truth concept image"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "5748dcb1",
      "metadata": {
        "id": "5748dcb1",
        "scrolled": false
      },
      "outputs": [],
      "source": [
        "# Reference image: the plain concept prompt sampled with the shared seed and settings.\n",
        "generator = torch.Generator(\"cuda\").manual_seed(target_seed)\n",
        "orig_image = pipe(\n",
        "    f'a photo of a {concept}',\n",
        "    guidance_scale=7.5,\n",
        "    generator=generator,\n",
        "    num_images_per_prompt=1,\n",
        "    num_inference_steps=num_inference_steps,\n",
        "    return_dict=False,\n",
        ")[0][0]\n",
        "# Last expression of the cell, so the 224x224 thumbnail is what gets displayed.\n",
        "orig_image.resize((224, 224))"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "bf62ce03",
      "metadata": {
        "id": "bf62ce03"
      },
      "source": [
        "### Visualize decomposition image"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "e405e39e",
      "metadata": {
        "id": "e405e39e"
      },
      "outputs": [],
      "source": [
        "# Render the placeholder token built from the filtered top coefficients.\n",
        "image = load_alphas(\n",
        "    alphas_projection=alphas,\n",
        "    token_embeddings=orig_embeddings,\n",
        "    seed=target_seed,\n",
        "    prompt=f'{prompt} \u003c\u003e',\n",
        ")\n",
        "image.resize((224, 224))"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "de48de44",
      "metadata": {
        "id": "de48de44"
      },
      "source": [
        "## Single-image decomposition code"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "29066be2",
      "metadata": {
        "id": "29066be2"
      },
      "source": [
        "### Iteratively remove features from the decomposition"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "34945bd7",
      "metadata": {
        "id": "34945bd7",
        "scrolled": true
      },
      "outputs": [],
      "source": [
        "# Greedy pruning: try to zero each remaining coefficient (plus its near-duplicate words) and\n",
        "# keep the change whenever the regenerated image stays close to the reference in CLIP space.\n",
        "# Passes repeat until a full pass removes nothing.\n",
        "with torch.no_grad():\n",
        "    final_alphas = alphas.clone()\n",
        "    # Reference features come from `image`, the render of the full decomposition. `image` is\n",
        "    # not reassigned in this cell, so re-running the cell prunes against the same reference.\n",
        "    target_clip = clip_processor(images=image, return_tensors=\"pt\")['pixel_values'].cuda()\n",
        "    target_clip = clip_model.get_image_features(target_clip)\n",
        "    # Visit candidates from the smallest kept coefficient to the largest.\n",
        "    indices = np.arange(len(top_word_idx))[::-1]\n",
        "    removed = True\n",
        "\n",
        "    while removed:\n",
        "        removed = False\n",
        "        next_indices = []\n",
        "        for idx in indices:\n",
        "            temp = final_alphas.clone()\n",
        "            temp[top_word_idx[idx]] = 0\n",
        "            # Also remove similar words\n",
        "            for similar_idx in clip_words_similarity[idx].nonzero()[0]:\n",
        "                temp[top_word_idx[similar_idx]] = 0\n",
        "            if torch.equal(temp, final_alphas):\n",
        "                # Everything this candidate would drop is already zero, so the render would be\n",
        "                # identical to the current one; skip the diffusion run and retire the index.\n",
        "                print(\"already removed, skipping token in idx: \", idx)\n",
        "                continue\n",
        "            candidate_image = load_alphas(temp, orig_embeddings, target_seed, f'{prompt} \u003c\u003e')\n",
        "\n",
        "            curr_clip = clip_processor(images=candidate_image, return_tensors=\"pt\")['pixel_values'].cuda()\n",
        "            curr_clip = clip_model.get_image_features(curr_clip)\n",
        "            similarity = torch.cosine_similarity(target_clip, curr_clip).item()\n",
        "            if similarity \u003e 0.93:\n",
        "                # Image is still close enough to the reference: accept the removal.\n",
        "                print(\"removing token in idx: \", idx)\n",
        "                final_alphas = temp\n",
        "                removed = True\n",
        "            else:\n",
        "                # Removal changed the image too much: keep the token and retry it next pass.\n",
        "                print(f\"similarity: {similarity} keeping token in idx: \", idx)\n",
        "                next_indices.append(idx)\n",
        "        indices = next_indices"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "87e5fe1a",
      "metadata": {
        "id": "87e5fe1a"
      },
      "source": [
        "### Visualize image after removing features"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "f74f42e6",
      "metadata": {
        "id": "f74f42e6"
      },
      "outputs": [],
      "source": [
        "# Render the placeholder token again using only the coefficients that survived pruning.\n",
        "image_decomp = load_alphas(\n",
        "    alphas_projection=final_alphas,\n",
        "    token_embeddings=orig_embeddings,\n",
        "    seed=target_seed,\n",
        "    prompt=f'{prompt} \u003c\u003e',\n",
        ")\n",
        "image_decomp.resize((224, 224))"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "6c922333",
      "metadata": {
        "id": "6c922333"
      },
      "source": [
        "### Visualize the remaining image features"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "02ee962b",
      "metadata": {
        "id": "02ee962b"
      },
      "outputs": [],
      "source": [
        "# For every token whose coefficient is still non-zero, print the decoded token and render\n",
        "# it on its own with the shared seed.\n",
        "remaining_features = torch.nonzero(final_alphas).flatten()\n",
        "for feature in remaining_features:\n",
        "    feature_word = pipe.tokenizer.decode(feature)\n",
        "    print(\"feature: \", feature_word)\n",
        "    generator = torch.Generator(\"cuda\").manual_seed(target_seed)\n",
        "    feature_visualization = pipe(\n",
        "        f'a photo of a {feature_word}',\n",
        "        guidance_scale=7.5,\n",
        "        generator=generator,\n",
        "        num_images_per_prompt=1,\n",
        "        num_inference_steps=num_inference_steps,\n",
        "        return_dict=False,\n",
        "    )[0][0]\n",
        "    display(feature_visualization.resize((224, 224)))"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3 (ipykernel)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}
